file_path = '/root/DSF/entities/cleaned_symptom_topN.txt'  # adjust file_path if necessary

# Read all lines up front; the context manager guarantees the handle is
# closed, and an explicit encoding keeps parsing independent of the
# platform's default locale encoding.
with open(file_path, 'r', encoding='utf-8') as symptom_file:
    symptom_data = symptom_file.readlines()

# Parallel lists: index i of each list describes the same record.
entity_ls = []    # entity name
prompt_ls = []    # example text in which the entity appears
topN_seq_ls = []  # candidate standardized terms, kept as the raw string

for line_no, line in enumerate(symptom_data, start=1):
    line = line.strip()
    if not line:
        # Blank lines (e.g. a trailing newline at end of file) carry no
        # record and would otherwise fail the 3-way unpack below.
        continue
    # maxsplit=2: any extra tabs stay inside the last (candidate) field
    # instead of causing a "too many values to unpack" error.
    fields = line.split('\t', 2)
    if len(fields) != 3:
        raise ValueError(
            f'{file_path}:{line_no}: expected 3 tab-separated fields '
            f'(entity, example, candidates), got {len(fields)}'
        )
    entity, example, topN_seq = fields
    entity_ls.append(entity)
    prompt_ls.append(example)
    topN_seq_ls.append(topN_seq)

# Print the first few parsed records as a quick sanity check of the input file.
NUM_PREVIEW = 3
for preview_entity, preview_example, preview_topN in zip(
    entity_ls[:NUM_PREVIEW], prompt_ls[:NUM_PREVIEW], topN_seq_ls[:NUM_PREVIEW]
):
    print('Entity:', preview_entity)
    print('Example:', preview_example)
    print('TopN Seq:', preview_topN, '\n')

# Few-shot demonstrations for the prompt: each pair is (example user
# message, expected model reply). All six are empty placeholders for now;
# fill them in to give the model worked examples.
format_few_shot1, format_few_shot1_response = '', ""
format_few_shot2, format_few_shot2_response = '', ""
format_few_shot3, format_few_shot3_response = '', ""

BATCH_SIZE = 4

# One system prompt for every item, so all entities (including those in the
# final, partial batch) are annotated under identical instructions and the
# same "none suitable" answer string.
SYSTEM_PROMPT = 'As a medical expert, help me annotate standardized medical term text. I provide entity names, text where the entity appears, potential standardized medical terms. Only give us the medical terms without reason. help me select the most appropriate standardized text from the following data. If none are suitable, please answer with "no standardized noun"'


def _build_conversation(entity, example, candidates):
    """Build the chat message list for a single entity.

    The conversation is the system prompt, then every non-empty few-shot
    demonstration as a user/assistant pair, then the actual query.

    Args:
        entity: Entity name to standardize.
        example: Text in which the entity appears.
        candidates: Candidate standardized terms, as read from the input file.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts to be rendered
        with ``tokenizer.apply_chat_template``.
    """
    conversation = [{"role": "system", "content": SYSTEM_PROMPT}]
    demonstrations = (
        (format_few_shot1, format_few_shot1_response),
        (format_few_shot2, format_few_shot2_response),
        (format_few_shot3, format_few_shot3_response),
    )
    for shot_prompt, shot_response in demonstrations:
        if not (shot_prompt or shot_response):
            # Unused placeholder pair: empty turns only add noise to the prompt.
            continue
        conversation.append({"role": "user", "content": shot_prompt})
        # A demonstration answer is a model turn, so it uses the "assistant"
        # role; some chat templates reject or mis-render "system" messages
        # that appear after the first position.
        conversation.append({"role": "assistant", "content": shot_response})
    conversation.append({
        "role": "user",
        "content": f"Extract Standard Name: Entity:{entity}, Example: {example}, Candidate: {candidates}",
    })
    return conversation


all_conversations = [
    _build_conversation(entity, example, candidates)
    for entity, example, candidates in zip(entity_ls, prompt_ls, topN_seq_ls)
]

# Split into consecutive batches of BATCH_SIZE; the final slice naturally
# holds the remainder, so no separate residual pass is needed.
batch_messages = [
    all_conversations[start:start + BATCH_SIZE]
    for start in range(0, len(all_conversations), BATCH_SIZE)
]

from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda"  # target device string; model weights themselves are placed by device_map="auto" below

# TODO: set to a Hugging Face model id or a local checkpoint directory.
MODEL_PATH = ""

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype="auto",
    device_map="auto",
)
# Left padding keeps every prompt's last token immediately before the
# generated continuation, which batched decoder-only generation relies on.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side='left')
# Batched padding requires a pad token; some causal-LM tokenizers ship
# without one, so fall back to the EOS token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

output_responses = []

for messages in batch_messages:
    # Render each conversation to a prompt string that ends with the
    # assistant generation prefix.
    chat_template_messages = [
        tokenizer.apply_chat_template(
            message,
            tokenize=False,
            add_generation_prompt=True,
        )
        for message in messages
    ]
    # add_special_tokens=False: the chat template already inserts the
    # model's special tokens, so adding them again (e.g. a duplicate BOS)
    # would alter the prompt. No truncation: with the default right-side
    # truncation an over-long prompt would silently lose its query and
    # generation prompt.
    model_inputs = tokenizer(
        chat_template_messages,
        padding='longest',
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)

    # Unpacking model_inputs passes the attention mask to generate();
    # with input_ids alone the model would attend to the left-pad tokens.
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=128,
        pad_token_id=tokenizer.pad_token_id,
    )
    # All rows are padded to the same prompt length, so the newly generated
    # tokens start at the same column for the whole batch.
    prompt_length = model_inputs.input_ids.shape[1]
    response = tokenizer.batch_decode(
        generated_ids[:, prompt_length:], skip_special_tokens=True
    )

    output_responses.extend(response)

# Write one line per entity, pairing its inputs with the model's answer.
output_file_path = "symptom_responses.txt"
if len(output_responses) != len(entity_ls):
    # zip() below stops at the shortest list, so report rather than silently drop rows.
    print(f"Warning: {len(output_responses)} responses for {len(entity_ls)} entities; output will be truncated")
# Explicit UTF-8 so non-ASCII entity/term text is written the same way on every platform.
with open(output_file_path, 'w', encoding='utf-8') as output_file:
    for entity, prompt, topN_seq, resp in zip(entity_ls, prompt_ls, topN_seq_ls, output_responses):
        # Flatten the response so each record occupies exactly one output line.
        resp = resp.strip().replace('\n', ' ').replace('\t', ' ')
        output_file.write(f'Entity: {entity}, Example: {prompt}, Candidate: {topN_seq}, Predicted Response: {resp}\n')

print("Responses have been written to", output_file_path)